[Feature][Performance] NextObservationDelta env transform#3777
Conversation
Adds a stateless env-side transform that stores `("next", obs)` as a
low-precision delta from the root `obs`, reducing the rollout-time
memory footprint of large continuous observations.
The transform compresses next observations in `_step` and rehydrates
the flowing tensordict's root observation in a new
`_post_step_mdp_hooks` extension point on `EnvBase`. The hook was
previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`;
it is now wired through `step_and_maybe_reset` and threaded into
`Transform`, `Compose`, and `TransformedEnv`.
Caveats documented on the class:
- The compression is lossy; round-trip error scales with delta dtype
precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked
output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB
storage). Pre-allocated buffers upcast the write.
- The hook fires from `step_and_maybe_reset`; direct `env.rollout()`
callers must rehydrate manually.
- `check_env_specs` rejects the transformed env in v1 because the
observation spec is shared between root and `("next", ...)` and we
do not fork it.
Includes a `TestNextObservationDelta` test class with 16 cases
(14 passing, 2 documented skips) covering single-env, serial/parallel
batched envs (inner and outer wrapping), auto-inference skipping
non-floating dtypes, multi-key, reset semantics, Compose ordering,
and an end-to-end `SyncDataCollector(use_buffers=False)` check that
the stacked batch carries `float16` `("next", obs)`.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3777
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 4 New FailuresAs of commit 0ef4983 with merge base 996387f ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
- Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so
`env.rollout(..., break_when_any_done=True)` rehydrates the flowing
td just like `step_and_maybe_reset` already did. The non-stop path
already routed through `step_and_maybe_reset` and is unchanged.
- Add `Transform.transform_fake_tensordict(td)` hook (no-op default),
iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict`
override. `NextObservationDelta` overrides it to cast the
`("next", key)` leaves to `delta_dtype` in the spec-derived fake td.
Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)`
now reserves storage at the compressed dtype rather than upcasting
writes; the collector test covers both `use_buffers={True, False}`.
- Add `Transform._check_batched_worker_compat()` (no-op default).
`NextObservationDelta` raises with a clear message pointing at the
correct usage pattern. `BatchedEnvBase._get_metadata` builds a
transient probe env and runs the validator via a new `env_validator`
kwarg on `get_env_metadata`, so the inner-batched configuration
fails loudly at construction time instead of silently upcasting at
runtime.
The remaining v1 caveat in the docstring is that `check_env_specs`
still does not pass: it calls `observation_spec.contains(("next", obs))`
and TorchRL shares `observation_spec` between root and `("next", ...)`
leaves, so a compressed dtype is rejected. Working around this
properly requires forking the spec system, which is out of scope for
this PR. Tests use a reset+step smoke instead.
Subtracting in delta_dtype (float16 by default) risks catastrophic cancellation when next_obs and obs are close. Doing the subtraction in the operands' source dtype and casting the result once preserves significand bits and is strictly more accurate on round-trip. The stored root obs is unchanged, so there is no asymmetry to preserve between the on-the-fly delta and the value reconstructed from storage.
|
@elin-bdai @theap06 maybe you could help review this one? |
elin-bdai
left a comment
There was a problem hiding this comment.
Thanks for doing this! I'm going to test this with our longer training jobs this week to make sure loss of precision is not a problem. Just a comment in terms of reducing confusion.
| >>> td_root = env.reset() | ||
| >>> _ = td_root.set("action", env.action_spec.rand()) | ||
| >>> td, td_ = env.step_and_maybe_reset(td_root) | ||
| >>> td["next", "observation"].dtype |
There was a problem hiding this comment.
If I'm understanding correctly, I think it's confusing here when using NextObservationDelta() that what's inside td["next", "observation"] is the delta, but the tensordict is indistinguishable from when you don't use NextObservationDelta, so you're not sure if it's the delta or not stored in there. It could lead to confusion when inspecting the outputs at different points.
| # operand to ``delta_dtype`` first and subtracting in low precision | ||
| # (which would risk catastrophic cancellation for nearby values). | ||
| delta = (next_obs - obs).to(self.delta_dtype) | ||
| next_tensordict.set(key, delta) |
There was a problem hiding this comment.
Would it make sense to change the key here to {key}_delta?
Summary
NextObservationDelta, a stateless env-side transform that stores("next", obs)as a low-precision delta from the rootobsfor rollout memory savings on large continuous observations._post_step_mdp_hooksextension point inEnvBase.step_and_maybe_resetand threads it throughTransform,Compose, andTransformedEnv. The hook receives both the post-step and post-step-mdp tensordicts so a transform can rehydrate the flowing td that the policy reads on the next iteration.NextObservationDelta._stepwrites(next_obs - obs).to(delta_dtype)(defaultfloat16);_post_step_mdp_hooksreconstructsobs + deltainrestore_dtype(default: match root). Stateless — no caching across steps.Why this shape
The existing
compact_obscollector flag +NextStateReconstructorRB transform attack the same problem by dropping("next", obs)entirely and shifting at sample time. That is zero-storage but lossy at trajectory boundaries (which becomeNaN). The delta variant trades a small precision loss for boundary-preserving reconstruction and an env-side hook that does not need to know about collector internals.The
_post_step_mdp_hooksmechanism was already stubbed (commented out) incommon.py,transforms/_base.py, andllm/chat.py. This PR enables it. The signature was changed from the original comment ((tensordict_,) -> tensordict_) to(tensordict, tensordict_) -> tensordict_because rehydration needs read access to the post-step root obs. No caller existed before, so this is not a breaking change.v1 limitations (documented on the class)
delta_dtypeprecision and observation magnitude.SyncDataCollector(use_buffers=False)or a lazy RB storage. Pre-allocated_final_rolloutupcasts the write back to the original dtype and erases the saving.step_and_maybe_resetonly.env.rollout()is not wired in v1; direct rollout callers must rehydrate manually.check_env_specsdoes not pass on the transformed env.observation_specis shared between root and("next", ...)in TorchRL; the transform does not fork it in v1 (a follow-up could). Tests use a reset+step smoke instead.SerialEnv/ParallelEnv, the transform belongs outside the batched env (i.e.TransformedEnv(ParallelEnv(...), NextObservationDelta())) — that path uses the outerstep_and_maybe_resetand the hook fires. Putting the transform inside each worker is allowed and runs without error, but the outer batched env'sstep_and_maybe_resetdoes not currently propagate the hook so the stacked output upcasts.Out of scope (potential follow-ups)
observation_specso pre-allocated_final_rolloutbenefits from the compression._rollout_stop_earlyand inbatched_envs/async_envs/envpoolstep_and_maybe_reset.benchmarks/.Test plan
pytest test/transforms/test_observation_transforms.py::TestNextObservationDelta— 14 passed, 2 documented skips.pytest --doctest-modules torchrl/envs/transforms/_observation.py -k NextObservationDelta— passes.pytest test/envs/test_env_base.py— 47 passed, 4 skipped (no regressions from the hook wiring).GymEnv("Pendulum-v1")confirms("next", "observation").dtype == torch.float16post-step andtorch.float32on the flowing td, with bitwise-exact rehydration (max diff 0.0).Compose(NextObservationDelta, RewardSum)works in both orderings.